{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using SageMaker debugger to monitor attentions in BERT model training\n",
    "\n",
    "[BERT](https://arxiv.org/abs/1810.04805) is a deep bidirectional transformer model that achieves state-of the art results in NLP tasks like question answering, text classification and others.\n",
    "\n",
    "The paper [Visualizing Attention in Transformer-Based Language Representation Models [1]](https://arxiv.org/pdf/1904.02679.pdf) shows that plotting attentions and individual neurons in the query and key vectors can help to identify causes of incorrect model predictions.\n",
    "With SageMaker Debugger we can easily retrieve those tensors and plot them in real-time as training progresses which may help to understand what the model is learning. \n",
    "\n",
    "The animation below shows the attention scores of the first 20 input tokens for the first 10 iterations in the training.\n",
    "\n",
    "<img src='images/attention_scores.gif' width='350' /> \n",
    "Fig. 1: Attention scores of the first head in the 7th layer \n",
    "\n",
    "[1] *Visualizing Attention in Transformer-Based Language Representation Models*:  Jesse Vig, 2019, 1904.02679, arXiv"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Get tensors and visualize BERT model training in real-time\n",
    "In this section, we will retrieve the tensors of our training job and create the attention-head view and neuron view as described in [Visualizing Attention in Transformer-Based Language Representation Models [1]](https://arxiv.org/pdf/1904.02679.pdf).\n",
    "\n",
    "First we create the [trial](https://github.com/awslabs/sagemaker-debugger/blob/master/docs/analysis.md#Trial) that points to the tensors in S3:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "from smdebug.trials import create_trial\n",
    "trial = create_trial( '/tmp/tensors' )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in trial.tensor_names():\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we import a script that implements the visualization for attentation head view in Bokeh."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import attention_head_view, neuron_view\n",
    "from ipywidgets import interactive"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once the validation phase started, we can retrieve the tensors from S3. In particular we are interested in outputs of the attention cells which gives the attention score. First we get the tensor names of the attention scores:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tensor_names = []\n",
    "\n",
    "for tname in sorted(trial.tensor_names(regex='.*attention.dropout_output_0')):\n",
    "    tensor_names.append(tname)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we iterate over the available tensors of the validation phase. We retrieve tensor values with `trial.tensor(tname).value(step, modes.EVAL)`. Note: if training is still in progress, not all steps will be available yet. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "steps = trial.steps()\n",
    "tensors = {}\n",
    "\n",
    "for step in steps:\n",
    "    print(\"Reading tensors from step\", step)\n",
    "    for tname in tensor_names: \n",
    "        if tname not in tensors:\n",
    "            tensors[tname]={}\n",
    "        tensors[tname][step] = trial.tensor(tname).value(step)\n",
    "num_heads = tensors[tname][step].shape[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we get the query and key output tensor names:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "layers = []\n",
    "layer_names = {}\n",
    "\n",
    "for index, (key, query) in enumerate(zip(trial.tensor_names(regex='.*k_lin_output_0'), trial.tensor_names(regex='.*q_lin_output_0'))):\n",
    "    layers.append([key,query])\n",
    "    layer_names[key.split('_')[0]] = index"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We also retrieve the string representation of the input tokens that were input into our model during validation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_tokens = trial.tensor('input_tokens').value(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Attention Head View\n",
    "\n",
    "The attention-head view shows the attention scores between different tokens. The thicker the line the higher the score. For demonstration purposes, we will limit the visualization to the first 20 tokens. We can select different attention heads and different layers. As training progresses attention scores change and we can check that by selecting a different step. \n",
    "\n",
    "**Note:** The following cells run fine in Jupyter. If you are using JupyterLab and encounter issues with the jupyter widgets (e.g. dropdown menu not displaying), check the subsection in the end of the notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_tokens = 20\n",
    "view = attention_head_view.AttentionHeadView(input_tokens, \n",
    "                                             tensors,  \n",
    "                                             step=trial.steps()[0],\n",
    "                                             layer='distilbert.transformer.layer.0.attention.dropout_output_0',\n",
    "                                             n_tokens=n_tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "interactive(view.select_layer, layer=tensor_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "interactive(view.select_head, head=np.arange(num_heads))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following code cell updates the dictionary `tensors`  with the latest tensors from the training the job. Once the dict is updated we can go to above code cell `attention_head_view.AttentionHeadView` and re-execute this and subsequent cells in order to plot latest attentions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_steps = trial.steps()\n",
    "new_steps = list(set(all_steps).symmetric_difference(set(steps)))\n",
    "\n",
    "for step in new_steps: \n",
    "    for tname in tensor_names:  \n",
    "        if tname not in tensors:\n",
    "            tensors[tname]={}\n",
    "        tensors[tname][step] = trial.tensor(tname).value(step)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Neuron view\n",
    "\n",
    "To create the neuron view as described in paper [Visualizing Attention in Transformer-Based Language Representation Models [1]](https://arxiv.org/pdf/1904.02679.pdf), we need to retrieve the queries and keys from the model. The tensors are reshaped and transposed to have the shape: *batch size, number of attention heads, sequence length, attention head size*\n",
    "\n",
    "**Note:** The following cells run fine in Jupyter. If you are using JupyterLab and encounter issues with the jupyter widgets (e.g. dropdown menu not displaying), check the subsection in the end of the notebook."
   ]
  },
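  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of the reshape and transpose described above, the following standalone NumPy snippet applies the same two operations to a random array. The sizes are arbitrary values chosen for illustration, not the dimensions of the trained model, and the variable names are hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Toy sizes, assumed for illustration only\n",
    "batch_size, seq_len, n_heads, head_size = 2, 5, 4, 3\n",
    "\n",
    "# A linear-layer output has shape (batch size, sequence length, hidden size)\n",
    "flat = np.random.rand(batch_size, seq_len, n_heads * head_size)\n",
    "\n",
    "# Split the hidden dimension into (number of heads, head size) ...\n",
    "per_head = flat.reshape((batch_size, seq_len, n_heads, -1))\n",
    "# ... and move the head axis in front of the sequence axis\n",
    "per_head = per_head.transpose(0, 2, 1, 3)\n",
    "\n",
    "print(per_head.shape)\n",
    "```\n",
    "\n",
    "The printed shape is ordered as *batch size, number of attention heads, sequence length, attention head size*."
   ]
  },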
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "queries = {}\n",
    "steps = trial.steps()\n",
    "\n",
    "for step in steps:\n",
    "    print(\"Reading tensors from step\", step)\n",
    "    \n",
    "    for tname in trial.tensor_names(regex='.*q_lin_output_0'):\n",
    "       query = trial.tensor(tname).value(step)\n",
    "       query = query.reshape((query.shape[0], query.shape[1], num_heads, -1))\n",
    "       query = query.transpose(0,2,1,3)\n",
    "       if tname not in queries:\n",
    "            queries[tname] = {}\n",
    "       queries[tname][step] = query"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Retrieve the key vectors:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "keys = {}\n",
    "steps = trial.steps()\n",
    "\n",
    "for step in steps:\n",
    "    print(\"Reading tensors from step\", step)\n",
    "    \n",
    "    for tname in trial.tensor_names(regex='.*k_lin_output_0'):\n",
    "       key = trial.tensor(tname).value(step)\n",
    "       key = key.reshape((key.shape[0], key.shape[1], num_heads, -1))\n",
    "       key = key.transpose(0,2,1,3)\n",
    "       if tname not in keys:\n",
    "            keys[tname] = {}\n",
    "       keys[tname][step] = key"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now select different query vectors and see how they produce different attention scores. We can also select different steps to see how attention scores, query and key vectors change as training progresses. The neuron view shows:\n",
    "* Query\n",
    "* Key\n",
    "* Query x Key (element wise product)\n",
    "* Query * Key (dot product)\n"
   ]
  },
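  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The last two quantities listed above can be sketched in NumPy as follows. This is a standalone illustration with hypothetical toy vectors, not the code used inside `neuron_view`: the element-wise product keeps one value per neuron, and summing those values gives the dot product, that is, the raw score before scaling and softmax.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical query and key vectors of one attention head (head size 4)\n",
    "q = np.array([0.5, -1.0, 2.0, 0.0])\n",
    "k = np.array([1.0, 0.5, 0.25, 3.0])\n",
    "\n",
    "# Element-wise product: one contribution per neuron\n",
    "elementwise = q * k\n",
    "\n",
    "# Dot product: the sum of the per-neuron contributions\n",
    "dot = np.dot(q, k)\n",
    "\n",
    "print(elementwise)\n",
    "print(dot)\n",
    "```\n",
    "\n",
    "Neurons with a large positive or negative element-wise product are the ones that contribute most to the score between the two tokens."
   ]
  },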
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "view = neuron_view.NeuronView(input_tokens, \n",
    "                              keys=keys, \n",
    "                              queries=queries, \n",
    "                              layers=layers, \n",
    "                              step=trial.steps()[0], \n",
    "                              n_tokens=n_tokens,\n",
    "                              layer_names=layer_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "interactive(view.select_query, query=np.arange(n_tokens))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "interactive(view.select_layer, layer=layer_names.keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Note: Jupyter widgets in JupyterLab\n",
    "\n",
    "If you encounter issues with this notebook in JupyterLab, you may have to install JupyterLab extensions. You can do this by defining a SageMaker [Lifecycle configuration](https://docs.aws.amazon.com/sagemaker/latest/dg/notebook-lifecycle-config.html). A lifecycle configuration is a shell script that runs when you either create a notebook instance or whenever you start an instance. You can create a Lifecycle configuration directly in the SageMaker console (more details [here](https://aws.amazon.com/blogs/machine-learning/customize-your-amazon-sagemaker-notebook-instances-with-lifecycle-configurations-and-the-option-to-disable-internet-access/)) When selecting `Start notebook`, copy and paste the following code. Once the configuration is created attach it to your notebook instance and start the instance.\n",
    "\n",
    "```sh\n",
    "#!/bin/bash\n",
    "\n",
    "set -e\n",
    "\n",
    "# OVERVIEW\n",
    "# This script installs a single jupyter notebook extension package in SageMaker Notebook Instance\n",
    "# For more details of the example extension, see https://github.com/jupyter-widgets/ipywidgets\n",
    "\n",
    "sudo -u ec2-user -i <<'EOF'\n",
    "\n",
    "# PARAMETERS\n",
    "PIP_PACKAGE_NAME=ipywidgets\n",
    "EXTENSION_NAME=widgetsnbextension\n",
    "\n",
    "source /home/ec2-user/anaconda3/bin/activate JupyterSystemEnv\n",
    "\n",
    "pip install $PIP_PACKAGE_NAME\n",
    "jupyter nbextension enable $EXTENSION_NAME --py --sys-prefix\n",
    "jupyter labextension install @jupyter-widgets/jupyterlab-manager\n",
    "# run the command in background to avoid timeout \n",
    "nohup jupyter labextension install @bokeh/jupyter_bokeh &\n",
    "\n",
    "source /home/ec2-user/anaconda3/bin/deactivate\n",
    "\n",
    "EOF\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_mxnet_p36",
   "language": "python",
   "name": "conda_mxnet_p36"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
